# Copyright 2024 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A pip_parse implementation for version aware toolchains in WORKSPACE."""

load("//python/private:text_util.bzl", "render")
load(":pip_repository.bzl", pip_parse = "pip_repository")

def _multi_pip_parse_impl(rctx):
    """Writes the hub `requirements.bzl` that fans out to per-version pip repositories.

    For every `python_version -> pip repository name` entry in
    `rctx.attr.pip_parses`, the generated file loads that repository's
    `requirements.bzl`, records which wheels it provides, and exposes an
    `install_deps` macro. That macro calls each repository's own
    `install_deps` and then declares one `whl_library_alias` per distinct
    wheel name.

    Args:
        rctx: the repository context of the `_multi_pip_parse` repository rule.
    """
    rules_python = rctx.attr._rules_python_workspace.workspace_name
    load_statements = []
    install_deps_calls = []
    process_requirements_calls = []
    for python_version, pypi_repository in rctx.attr.pip_parses.items():
        # Dots are not valid in identifiers, so "3.11" becomes "3_11" for the
        # aliased symbol names used in the generated file.
        sanitized_python_version = python_version.replace(".", "_")
        load_statement = """\
load(
    "@{pypi_repository}//:requirements.bzl",
    _{sanitized_python_version}_install_deps = "install_deps",
    _{sanitized_python_version}_all_requirements = "all_requirements",
)""".format(
            pypi_repository = pypi_repository,
            sanitized_python_version = sanitized_python_version,
        )
        load_statements.append(load_statement)
        process_requirements_call = """\
_process_requirements(
    pkg_labels = _{sanitized_python_version}_all_requirements,
    python_version = "{python_version}",
    repo_prefix = "{pypi_repository}_",
)""".format(
            pypi_repository = pypi_repository,
            python_version = python_version,
            sanitized_python_version = sanitized_python_version,
        )
        process_requirements_calls.append(process_requirements_call)
        install_deps_call = """    _{sanitized_python_version}_install_deps(**whl_library_kwargs)""".format(
            sanitized_python_version = sanitized_python_version,
        )
        install_deps_calls.append(install_deps_call)

    # NOTE @aignas 2023-10-31: I am not sure it is possible to render aliases
    # for all of the packages using the `render_pkg_aliases` function because
    # we need to know what the list of packages for each version is and then
    # we would be creating directories for each.
    macro_tmpl = "@%s_{}//:{}" % rctx.attr.name

    # The generated `_process_requirements` runs once per Python version, and
    # the same wheel usually shows up for several versions. `_wheel_names` must
    # therefore only record a wheel the first time it is seen; otherwise the
    # generated `install_deps` would declare the same `whl_library_alias`
    # repository once per version.
    requirements_bzl = """\
# Generated by python/pip.bzl

load("@{rules_python}//python:pip.bzl", "whl_library_alias", "pip_utils")
{load_statements}

_wheel_names = []
_version_map = dict()
def _process_requirements(pkg_labels, python_version, repo_prefix):
    for pkg_label in pkg_labels:
        wheel_name = Label(pkg_label).package
        if not wheel_name:
            # We are dealing with the cases where we don't have aliases.
            workspace_name = Label(pkg_label).workspace_name
            wheel_name = workspace_name[len(repo_prefix):]

        if not wheel_name in _wheel_names:
            _wheel_names.append(wheel_name)
        if not wheel_name in _version_map:
            _version_map[wheel_name] = dict()
        _version_map[wheel_name][python_version] = repo_prefix

{process_requirements_calls}

def requirement(name):
    return "{macro_tmpl}".format(pip_utils.normalize_name(name), "pkg")

def whl_requirement(name):
    return "{macro_tmpl}".format(pip_utils.normalize_name(name), "whl")

def data_requirement(name):
    return "{macro_tmpl}".format(pip_utils.normalize_name(name), "data")

def dist_info_requirement(name):
    return "{macro_tmpl}".format(pip_utils.normalize_name(name), "dist_info")

def install_deps(**whl_library_kwargs):
{install_deps_calls}
    for wheel_name in _wheel_names:
        whl_library_alias(
            name = "{name}_" + wheel_name,
            wheel_name = wheel_name,
            default_version = "{default_version}",
            minor_mapping = {minor_mapping},
            version_map = _version_map[wheel_name],
        )
""".format(
        name = rctx.attr.name,
        install_deps_calls = "\n".join(install_deps_calls),
        load_statements = "\n".join(load_statements),
        macro_tmpl = macro_tmpl,
        process_requirements_calls = "\n".join(process_requirements_calls),
        rules_python = rules_python,
        default_version = rctx.attr.default_version,
        minor_mapping = render.indent(render.dict(rctx.attr.minor_mapping)).lstrip(),
    )
    rctx.file("requirements.bzl", requirements_bzl)

    # Export the generated file so that it can be `load`ed from this repository.
    rctx.file("BUILD.bazel", "exports_files(['requirements.bzl'])")

# Repository rule that writes the hub `requirements.bzl` and `BUILD.bazel`
# files; instantiated by the `multi_pip_parse` macro in this file.
_multi_pip_parse = repository_rule(
    implementation = _multi_pip_parse_impl,
    attrs = {
        # Substituted into the generated `whl_library_alias` calls.
        "default_version": attr.string(),
        # Rendered as a dict literal into the generated `whl_library_alias` calls.
        "minor_mapping": attr.string_dict(),
        # Maps each Python version to the name of its pip repository.
        "pip_parses": attr.string_dict(),
        # Only used to look up the workspace name of this ruleset.
        "_rules_python_workspace": attr.label(default = Label("//:WORKSPACE")),
    },
)

def multi_pip_parse(name, default_version, python_versions, python_interpreter_target, requirements_lock, minor_mapping, **kwargs):
    """NOT INTENDED FOR DIRECT USE!

    Internal entry point used by the multi_pip_parse implementation in the
    template of the multi_toolchain_aliases repository rule. It declares one
    pip_parse repository per Python version and then the hub repository that
    ties them together.

    Args:
        name: the name of the multi_pip_parse repository; it is also the
            prefix of every per-version pip_parse repository name.
        default_version: {type}`str` the default Python version.
        python_versions: {type}`list[str]` all Python toolchain versions
            currently registered.
        python_interpreter_target: {type}`dict[str, Label]` resolved host
            interpreters keyed by Python version.
        requirements_lock: {type}`dict[str, Label]` locked requirements files
            keyed by Python version.
        minor_mapping: {type}`dict[str, str]` mapping between `X.Y` to `X.Y.Z`
            format.
        **kwargs: extra arguments passed to all wrapped pip_parse.

    Returns:
        Whatever instantiating the internal `_multi_pip_parse` repository rule
        returns.
    """
    repo_by_version = {}
    for version in python_versions:
        # Every requested version needs both an interpreter and a lock file.
        if version not in python_interpreter_target:
            fail("Missing python_interpreter_target for Python version %s in '%s'" % (version, name))
        if version not in requirements_lock:
            fail("Missing requirements_lock for Python version %s in '%s'" % (version, name))

        # E.g. name "pypi" and version "3.11" give the repository "pypi_3_11".
        repo_name = "_".join([name, version.replace(".", "_")])
        pip_parse(
            name = repo_name,
            python_interpreter_target = python_interpreter_target[version],
            requirements_lock = requirements_lock[version],
            **kwargs
        )
        repo_by_version[version] = repo_name

    return _multi_pip_parse(
        name = name,
        pip_parses = repo_by_version,
        default_version = default_version,
        minor_mapping = minor_mapping,
    )
